using DataFrames
using QuadGK
using Distributions
using JLD
using Optim
using ForwardDiff
using CSV
using LinearAlgebra
using Random
using DelimitedFiles

# Clear the console only when a `clearconsole` helper is in scope. It is provided by the
# Juno/Atom IDE environment rather than Base Julia, so an unconditional call stops the
# script when it is run with plain `julia`.
@isdefined(clearconsole) && clearconsole()

#-------------------------------------------------------------
# include Functions
#-------------------------------------------------------------
#cd("$(pwd())/Dropbox/Heterogeneity/Software/KS_Simulation/")
# Helper routines are located relative to the current working directory, so the script
# has to be started from the project root.
readDir = "$(pwd())/Functions/"
include(readDir *"logSpline_Procedures.jl");
include(readDir *"fKF.jl");

#-------------------------------------------------------------
# load specification files
#-------------------------------------------------------------
# specDir already ends in "/", so file names are appended without a leading slash
# (appending "/Basespec1.jl" produced a ".../SpecFiles//Basespec1.jl" path).
specDir   = "$(pwd())/SpecFiles/"
include(specDir * "Basespec1.jl")
include(specDir * "Dataspec3.jl")

#-------------------------------------------------------------
# Load fVAR parameters (posterior means, "pmean" files)
loadDir  = "$(pwd())/data/Para1/";
# PHIpmean is passed (transposed) to KF0 further below, presumably as the fVAR
# coefficient matrix.
PHIpmean   = readdlm(loadDir * "fVAR3_SS3_MCMC1_Bpmean.csv", ',', Float64, '\n'; skipstart=1)
# SIGMApmean must be read from its own file. Re-reading "_Bpmean.csv" here made it an
# exact copy of PHIpmean.
# NOTE(review): the "_Sigmapmean" file name is inferred from the "_Bpmean" naming
# scheme - confirm it against the contents of data/Para1/.
SIGMApmean = readdlm(loadDir * "fVAR3_SS3_MCMC1_Sigmapmean.csv", ',', Float64, '\n'; skipstart=1)
# A covariance matrix must be square; fail here rather than deep inside the filter.
size(SIGMApmean, 1) == size(SIGMApmean, 2) ||
    throw(DimensionMismatch("SIGMApmean must be square, got size $(size(SIGMApmean))"))

#-------------------------------------------------------------
# Compute log spline density estimation for (N,K) and t
#-------------------------------------------------------------

# Output storage, filled by the per-period loop below (row tt <-> period tt).
PhatDensValue = zeros(Float64, T, length(xgrid))  # fitted density evaluated on xgrid
PhatDensCoef  = zeros(Float64, T, K)              # log-spline coefficients
PhatDensNorm  = zeros(Float64, T, 1)              # value returned by lnpdfNormalize
PhatlogLike   = zeros(Float64, T, 1)              # per-period log likelihood
Vinv_all      = zeros(Float64, K * T, K)          # stacked K x K negative-Hessian blocks
N_all         = zeros(Float64, T, 1)              # number of selected observations
Period_all    = zeros(Float64, T, 1)              # not written in the visible code
N_details     = zeros(Float64, T, 4 + K)          # [N, N_max, pi_hat, max value, knot counts]

    # Minimize the TwiceDifferentiable objective `td` with Newton's method, trying each
    # starting vector in `starts` in order. Return the first successful Optim result, or
    # `nothing` if every attempt threw. Interrupts (Ctrl-C) are rethrown instead of being
    # treated as an optimizer failure.
    function _newton_with_restarts(td, starts)
        for alpha0 in starts
            try
                return optimize(td, alpha0, Newton())
            catch err
                err isa InterruptException && rethrow()
            end
        end
        return nothing
    end

    # Period-by-period log-spline MLE. Iteration tt fills row tt of PhatDensCoef,
    # PhatDensNorm, PhatDensValue, PhatlogLike, N_all and N_details, and the tt-th K x K
    # block of Vinv_all.
    for tt = 1:T
        time_init_loop = time_ns()

        # time t data: rescale by K_exact[tt], keep only draws with employdraws == 1
        densdraws_t     = densdraws[1:N,tt]./K_exact[tt]
        employdraws_t   = employdraws[1:N,tt]
        selecteddraws_t = densdraws_t[employdraws_t.==1]
        N_all[tt]       = length(selecteddraws_t)
        xgrid_lo        = minimum(xgrid)
        xgrid_hi        = maximum(xgrid)

        # count observations with knot restriction
        # recall that there are K-1 knots
        N_knots      = zeros(1,K);
        N_knots[1,1] = sum(selecteddraws_t.<=knots[1])
        for kk = 2:(K-1)
            N_knots[1,kk] = sum((selecteddraws_t.>knots[kk-1]) .& (selecteddraws_t.<=knots[kk]))
        end
        N_knots[1,K] = sum(selecteddraws_t.>knots[end])
        println("Number of obs in knot brackets: $(N_knots)")
        println("Max knot:        $(knots[K-1])")

        # compute MLE, warm-starting from the previous period's coefficients
        alpha_initial = tt == 1 ? zeros(K) : PhatDensCoef[tt-1,:]

        if TopCodeFlag == 0
            # without top-coding: Newton from the warm start, then Newton from zero,
            # then BFGS from zero as a last resort
            f1 = alpha -> logspline_obj(selecteddraws_t, alpha, knots, xgrid_lo, xgrid_hi)
            td = TwiceDifferentiable(f1, alpha_initial; autodiff = :forward)
            results_t = _newton_with_restarts(td, (alpha_initial, zeros(K)))
            if results_t === nothing
                results_t = optimize(f1, zeros(K), BFGS())
            end
        else
            # with top-coding: Newton from the warm start, then BFGS from the same start
            f2 = alpha -> logspline_obj_topcode(selecteddraws_t, alpha, knots, xgrid_lo, xgrid_hi)
            td = TwiceDifferentiable(f2, alpha_initial; autodiff = :forward)
            results_t = _newton_with_restarts(td, (alpha_initial,))
            if results_t === nothing
                results_t = optimize(f2, alpha_initial, BFGS())
            end
        end

        coef_t    = results_t.minimizer
        C_topcode = maximum(selecteddraws_t)
        N_max     = sum(selecteddraws_t.==C_topcode)

        # The top-coded likelihood/Hessian apply only when top-coding is enabled AND more
        # than one observation sits at the sample maximum.
        # NOTE(review): with TopCodeFlag != 0 and N_max == 1, the coefficients above come
        # from the top-coded objective, while the likelihood and Hessian below use the
        # non-top-coded formulas - confirm this mix is intended.
        topcoded = (TopCodeFlag != 0) && (N_max != 1)

        if topcoded
            # likelihood w top coding: spline part plus the Bernoulli log likelihood of
            # N_max top-coded observations with probability pi_hat
            println("Top coded value is $(C_topcode), number of obs is $(N_max)")
            pi_hat = N_max/N_all[tt]
            PhatlogLike[tt] = - N_all[tt]*results_t.minimum + (N_all[tt]-N_max)*log(1-pi_hat) + N_max*log(pi_hat)
        else
            # likelihood w/o top coding
            PhatlogLike[tt] = - N_all[tt]*results_t.minimum
            pi_hat = 0.0
            println("No top coding")
        end

        # results
        PhatDensCoef[tt,:]  = coef_t
        PhatDensNorm[tt,:]  = lnpdfNormalize(coef_t',knots, xgrid_lo, xgrid_hi)
        PhatDensValue[tt,:] = pdfEval(xgrid,coef_t,knots,[PhatDensNorm[tt,1]])';
        N_details[tt,:]     = [N_all[tt] N_max pi_hat C_topcode N_knots]

        # Negative Hessian of the log likelihood at the estimate, stored as Vinv_t.
        # hessian_loglh is expected to return a `Symmetric` matrix (so Vinv_t is
        # Symmetric too) - TODO confirm against logSpline_Procedures.jl.
        if topcoded
            # Hessian w top coding: scaled by the share of non-top-coded observations
            Hess_t = (N_all[tt]-N_max)/N_all[tt] *
                hessian_loglh(PhatDensCoef[tt,:], knots, xgrid_lo, C_topcode)
        else
            # Hessian w/o top coding
            Hess_t = hessian_loglh(PhatDensCoef[tt,:], knots, xgrid_lo, xgrid_hi)
        end

        Vinv_t = - Hess_t
        if !isposdef(Vinv_t)
            # Replace each eigenvalue by its absolute value so that Vinv_t can be used as
            # a precision matrix downstream.
            Vinv_eig = eigen(Vinv_t)
            Vinv_t = Symmetric(Vinv_eig.vectors*Diagonal(abs.(Vinv_eig.values))*Vinv_eig.vectors')
            println("Flipped neg eigenvalues")
        end
        Vinv_all[K*(tt-1)+1:K*tt,:] = Vinv_t

        println("K = $(K), Period $(tt)")
        println("Vinv_t is positive definite: $(isposdef(Vinv_t))")
        time_loop=signed(time_ns()-time_init_loop)/1000000000
        println("Elapsed Time $(time_loop) seconds")

    end

    # Compress the T x K coefficient matrix with coefCompress, which returns three
    # outputs: factors, lambda and mean (the rows of lambda set the compressed
    # dimension used in the measurement-error step).
    (PhatDensCoef_factor, PhatDensCoef_lambda, PhatDensCoef_mean ) = coefCompress(PhatDensCoef)
    Ktilde = size(PhatDensCoef_factor, 2)
    println("----------------")
    println("Compression Step")
    println("K = $(K), K-tilde = $(Ktilde)")
    println("----------------")
    println("")

    # Goodness of Fit (GoF) is the per-period log likelihood. Copy it so that later
    # updates to MDD_GoF cannot alias PhatlogLike.
    MDD_GoF   = copy(PhatlogLike);

# save results
sNameFile = "N" * string(N) * "_T"*string(T)
savedir = "$(pwd())/results/KF0_K"*string(K)*"/";
# mkpath creates any missing parent directories (e.g. results/) and is a no-op when
# savedir already exists, so a genuine failure (such as a permission error) is
# reported here instead of being silently swallowed.
mkpath(savedir)
# One CSV per result matrix, named <savedir><sNameFile><suffix>.
for (suffix, result) in (
        "_PhatDensValue.csv"       => PhatDensValue,
        "_PhatDensCoef.csv"        => PhatDensCoef,
        "_PhatDensCoef_factor.csv" => PhatDensCoef_factor,
        "_PhatDensCoef_lambda.csv" => PhatDensCoef_lambda,
        "_PhatDensCoef_mean.csv"   => PhatDensCoef_mean,
        "_Vinv_all.csv"            => Vinv_all,
        "_N_all.csv"               => N_all,
        "_N_details.csv"           => N_details,
        "_MDD_GoF.csv"             => MDD_GoF,
    )
    CSV.write(savedir * sNameFile * suffix, DataFrame(result, :auto))
end

#-------------------------------------------------------------
# Given the alpha-hats, the aggregate variables, and fVAR parameters,
# we can run the filter and record the likelihood increments and the filtered alphas.
#-------------------------------------------------------------
data    = [ agg_data[1:T,:] PhatDensCoef_factor ]

# Measurement-error covariance matrices H_t[:, :, tt], one per period. The aggregate
# block (first n_agg rows/columns) is left at zero. The density block is
#   (N_t * Lambda * Vinv_t * Lambda')^{-1},   Lambda = PhatDensCoef_lambda.
# Lambda multiplies the K x K matrix Vinv_t, so it has K columns. The density block
# therefore has one row/column per row of Lambda (K when Lambda is square).
n_dens = size(PhatDensCoef_lambda, 1)
H_t = zeros(n_agg+n_dens, n_agg+n_dens, T)
for tt = 1:T
    Vinv_t    = Symmetric(Vinv_all[K*(tt-1)+1:K*tt,:])
    VinvLam_t = PhatDensCoef_lambda*Vinv_t*PhatDensCoef_lambda'*N_all[tt]
    H_t[n_agg+1:end, n_agg+1:end, tt] = inv(Symmetric(VinvLam_t))
end

# Run Kalman Filter
s_filtered, lh_filtered  = KF0(H_t,SIGMApmean,PHIpmean',data,n_agg)

# add X density part for periods 2..T:
#   log-spline log likelihood + (n_dens/2) log(2*pi) + (1/2) log det(density block of H_t)
# logdet avoids the overflow/underflow that log(det(.)) can hit for badly scaled blocks.
for tt = 2:T
    lh_filtered[tt] = lh_filtered[tt] + MDD_GoF[tt] + 0.5*n_dens*log(2*pi) +
        0.5*logdet(H_t[n_agg+1:end, n_agg+1:end, tt])
end

# append the total log likelihood as the final entry
lh_filtered = [lh_filtered; sum(lh_filtered)]

# save filtered_alphas and likelihood increments
alphas_filtered = s_filtered[:,n_agg+1:end]
sNameFile = "N" * string(N) * "_T"*string(T)
savedir = "$(pwd())/results/KF0_K"*string(K)*"/";
# mkpath creates missing parents and is a no-op when savedir exists; real failures
# (e.g. permissions) surface here instead of being swallowed.
mkpath(savedir)
CSV.write(savedir * sNameFile * "_alphas_filtered.csv", DataFrame(alphas_filtered,:auto))
# DataFrame(x, :auto) needs a matrix (a plain Vector is rejected). reshape to a single
# column works whether KF0 returned the increments as a vector or a one-column matrix.
CSV.write(savedir * sNameFile * "_lh_filtered.csv", DataFrame(reshape(lh_filtered, :, 1),:auto))
